import typer
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer
from rich import print
from typer import Argument, Option
from typing import Optional


def prepare_data(dataset_save_name: str = Argument(..., help='数据集保存名称'),
                 tokenizer_name_or_path: str = Argument(..., help='tokenizer名称或路径'),
                 cache_dir: str = Argument(..., help='缓存目录，用于保存处理后的数据集'),
                 corpus_dir: str = Argument(..., help='原始语料目录，将会读取目录中所有的jsonl文件'),
                 num_proc: int = Option(20, help='处理进程数，默认为20'),
                 use_prompt: Optional[bool] = Option(None, help='是否使用prompt，默认为None，即使用prompt'),
                 max_length: int = Option(512, help='文本处理最大长度，默认为512')):
    """Tokenize every JSONL file in a directory and save the result to disk.

    All ``*.jsonl`` files directly inside ``corpus_dir`` (non-recursive) are
    loaded with ``load_dataset('json', ...)``. Each record is turned into a
    text and tokenized with truncation at ``max_length`` and
    ``return_overflowing_tokens=True``. The processed dataset, which keeps
    only an ``input_ids`` column, is saved under
    ``<cache_dir>/<dataset_save_name>``.

    The text for each record depends on ``use_prompt``:
    ``pre-prompt + content + post-prompt`` when prompting is used, or just
    ``content`` otherwise.

    Args:
        dataset_save_name: Name of the output directory created under ``cache_dir``.
        tokenizer_name_or_path: Passed to ``AutoTokenizer.from_pretrained``.
        cache_dir: Cache directory for ``load_dataset`` and parent directory
            of the saved dataset.
        corpus_dir: Directory scanned for ``.jsonl`` files.
        num_proc: Number of worker processes used by ``map``.
        use_prompt: Whether to wrap ``content`` with the ``pre-prompt`` and
            ``post-prompt`` fields. ``None`` (the default) is treated as
            ``True``, as the CLI help text states. ``False`` tokenizes
            ``content`` alone, so the prompt columns are not required.
        max_length: Maximum number of tokens per tokenized sequence.

    Raises:
        typer.BadParameter: If ``corpus_dir`` contains no ``.jsonl`` files.
    """
    # The help text promises that the default (None) means "use the prompt";
    # make that explicit instead of relying on None being falsy.
    if use_prompt is None:
        use_prompt = True

    print('load tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    if not tokenizer.pad_token:
        # No padding happens in this script. Presumably this is set so that
        # downstream collators can pad; EOS is reused as the pad token.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Match the real '.jsonl' suffix (not any name ending in 'jsonl'), and
    # sort so the dataset row order does not depend on filesystem order.
    data_files = sorted(
        str(p) for p in Path(corpus_dir).iterdir()
        if p.is_file() and p.suffix == '.jsonl'
    )
    if not data_files:
        raise typer.BadParameter(f'no .jsonl files found in {corpus_dir}',
                                 param_hint='corpus_dir')

    print(f'all data files: {data_files}')
    ds = load_dataset('json', data_files=data_files, cache_dir=cache_dir)

    def build_texts(examples):
        """Return the list of texts to tokenize for one batch of records."""
        if use_prompt:
            return [pre_prompt + content + post_prompt
                    for pre_prompt, content, post_prompt
                    in zip(examples['pre-prompt'], examples['content'], examples['post-prompt'])]
        return examples['content']

    def tokenize(examples):
        """Tokenize one batch and keep only ``input_ids``.

        With ``return_overflowing_tokens=True``, a fast tokenizer may return
        more rows than the batch had records, one per overflow chunk. That
        is why the original columns are dropped via ``remove_columns`` below.
        """
        inputs = tokenizer(build_texts(examples),
                           return_overflowing_tokens=True,
                           max_length=max_length,
                           truncation=True)
        return {'input_ids': inputs['input_ids']}

    print('preprocess dataset')
    ds = ds.map(tokenize, batched=True, num_proc=num_proc,
                remove_columns=ds['train'].column_names)

    print('save dataset')
    save_path = Path(cache_dir, dataset_save_name)
    # str() keeps compatibility with datasets versions that expect a str path.
    ds.save_to_disk(str(save_path))
    
    
if __name__ == "__main__":
    # Expose prepare_data as a single-command CLI. typer derives positional
    # arguments and --options from the function's Argument/Option defaults.
    typer.run(prepare_data)